-
Notifications
You must be signed in to change notification settings - Fork 388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(gry): refactor reward model #636
base: main
Are you sure you want to change the base?
Conversation
Codecov Report
@@ Coverage Diff @@
## main #636 +/- ##
==========================================
+ Coverage 82.06% 83.57% +1.51%
==========================================
Files 586 580 -6
Lines 47515 47428 -87
==========================================
+ Hits 38991 39640 +649
+ Misses 8524 7788 -736
Flags with carried forward coverage won't be shown. Click here to find out more.
|
@@ -32,3 +33,72 @@ def observation(self, obs): | |||
# print('vis_mask:' + vis_mask) | |||
image = grid.encode(vis_mask) | |||
return {**obs, "image": image} | |||
|
|||
|
|||
class ObsPlusPrevActRewWrapper(gym.Wrapper): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why add this wrapper here, rather than use the wrapper in ding/envs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because in ding/envs
, we use gym for the wrapper, but for minigrid we need gymnasium instead of gym. And in order to make a terrible influence on other env, I add this wrapper to minigrid wrapper.
@@ -10,16 +10,18 @@ | |||
), | |||
reward_model=dict( | |||
type='trex', | |||
exp_name='cartpole_trex_onppo_seed0', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why exp_name
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in our original implementation, we used exp_name to build the tb logger. So it uses the whole config file. our new implementation only uses the reward model config, so I add this part to the reward model.
@@ -201,6 +133,7 @@ def load_expert_data(self) -> None: | |||
with open(self.cfg.data_path + '/expert_data.pkl', 'rb') as f: | |||
self.expert_data_loader: list = pickle.load(f) | |||
self.expert_data = self.concat_state_action_pairs(self.expert_data_loader) | |||
self.expert_data = torch.unbind(self.expert_data, dim=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why unbind here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because we re-use the concat_state_action_pair function, and its return is different from the original function in Gail. So I used unbind here.
max_train_iter: Optional[int] = int(1e10), | ||
max_env_step: Optional[int] = int(1e10), | ||
cooptrain_reward: Optional[bool] = True, | ||
pretrain_reward: Optional[bool] = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments for new arguments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
# update reward_model, when you want to train reward_model inloop | ||
if cooptrain_reward: | ||
reward_model.train() | ||
# clear buffer per fix iters to make sure replay buffer's data count isn't too few. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clear buffer per fixed iters to make sure the data for RM training is not too offpolicy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
@@ -108,11 +111,11 @@ def serial_pipeline_reward_model_offpolicy( | |||
# collect data for reward_model training | |||
reward_model.collect_data(new_data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add if if cooptrain_reward
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
try: | ||
serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2) | ||
except Exception: | ||
assert False, "pipeline fail" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add finally branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
@@ -0,0 +1,106 @@ | |||
from typing import Optional, List, Any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
file name typo reword
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
@@ -22,44 +22,49 @@ | |||
stop_value=int(1e5), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove pitfall
and mnotezuma
config
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
|
||
# train reward model | ||
serial_pipeline_reward_model_offpolicy(main_config, create_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wrong usage here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
@@ -22,7 +22,6 @@ | |||
action_bins_per_branch=2, # mean the action shape is 6, 2 discrete actions for each action dimension |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why modify this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may be modified by format.sh, will I need to change it back?
@@ -24,6 +24,7 @@ | |||
update_per_collect=5, | |||
batch_size=64, | |||
learning_rate=0.001, | |||
learner=dict(hook=dict(save_ckpt_after_iter=100)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why add this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because when we do unit test at drex, we need to modify the learner.hook.save_ckpt_after_iter
, if we do not have this, the unit test will be failed, so I add this.
max_train_iter: Optional[int] = int(1e10), | ||
max_env_step: Optional[int] = int(1e10), | ||
cooptrain_reward: Optional[bool] = True, | ||
pretrain_reward: Optional[bool] = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pretrain_reward
-> pretrain_reward_model
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
model: Optional[torch.nn.Module] = None, | ||
max_train_iter: Optional[int] = int(1e10), | ||
max_env_step: Optional[int] = int(1e10), | ||
cooptrain_reward: Optional[bool] = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cooptrain_reward
-> joint_train_reward_model
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm) | ||
loss = self.reverse_scale * inverse_loss + forward_loss | ||
self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm) | ||
inverse_loss, forward_loss, accuracy = self.reward_model.learn(data_states, data_next_states, data_actions) | ||
loss = self.reverse_scale * inverse_loss + forward_loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.reverse_scale
-> self.reverse_loss_weight
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm) | ||
loss = self.reverse_scale * inverse_loss + forward_loss | ||
self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm) | ||
inverse_loss, forward_loss, accuracy = self.reward_model.learn(data_states, data_next_states, data_actions) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在这里 accuracy的含义是?增加注释,以及换一下变量名
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max | ||
elif self.intrinsic_reward_type == 'assign': | ||
item['reward'] = icm_rew | ||
train_data_augmented = combine_intrinsic_exterinsic_reward(train_data_augmented, icm_reward, self.cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
icm_reward
-> normalized_icm_reward
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
self.only_use_last_five_frames = config.only_use_last_five_frames_for_icm_rnd | ||
|
||
def _train(self) -> None: | ||
def _train(self) -> torch.Tensor: | ||
# sample episode's timestep index | ||
train_index = np.random.randint(low=0, high=self.train_obs.shape[0], size=self.cfg.batch_size) | ||
|
||
train_obs: torch.Tensor = self.train_obs[train_index].to(self.device) # shape (self.cfg.batch_size, obs_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的: torch.Tensor或许可以去掉,在上面写上overview格式的注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里具体指的是什么呢,为什么写了注释之后就可以不控制返回参数的类型
""" | ||
states_data = [] | ||
actions_data = [] | ||
#check data(dict) has key obs and action |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
空格 使用 bash format.sh ding 格式化代码
def clear_data(self, iter: int) -> None: | ||
assert hasattr( | ||
self.cfg, 'clear_buffer_per_iters' | ||
), "Reward Model does not have clear_buffer_per_iters, Clear failed" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错,可以给出修改建议,例如你需要参考xxx, 实现xxx方法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
type='rnd-ngu', | ||
), | ||
episodic_reward_model=dict( | ||
# means if using rescale trick to the last non-zero reward |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段注释可以用gpt4优化一下语法
type='rnd-ngu', | ||
), | ||
episodic_reward_model=dict( | ||
# means if using rescale trick to the last non-zero reward |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段注释可以用gpt4优化一下语法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
优化完毕
Description
It is a draft pr used for refactoring the reward model
Things finished
Refactoring
New system Design
Pipeline
Check List